import os
import sys
import glob
from osgeo import gdal
import numpy as np
import cv2
import tqdm
import os
# GDAL / PROJ runtime configuration, applied once at import time.
# NOTE(review): the two directory paths point into one specific local conda
# environment; they have to be adjusted on any other machine.
os.environ.update({
    # Encoding GDAL should assume for file names stored inside zip archives.
    'CPL_ZIP_ENCODING': 'UTF-8',
    # Directory holding the PROJ resource files.
    'PROJ_LIB': r'D:\anaconda3\envs\pytorch_GPU\Lib\site-packages\pyproj\proj_dir\share\proj',
    # Directory holding GDAL's support data files.
    'GDAL_DATA': r'D:\anaconda3\envs\pytorch_GPU\Library\share',
})

def CalHistogram(img):
    """Build the value histogram used by the percent-clip stretch.

    ``uint8`` and ``uint16`` images get exactly one bin per representable
    value (256 or 65536 bins) through ``np.bincount``. Every other dtype is
    binned by ``np.histogram`` into 65536 equal-width bins spanning the
    image's own minimum and maximum. ``uint32`` deliberately takes this
    second path: one bin per value would require 2**32 counters plus an
    equally long value table, i.e. tens of gigabytes of memory.

    In both cases the first and the last bin are zeroed, so pixels that sit
    at the two extremes do not influence the percentile search
    (presumably to keep nodata / saturated pixels out of the statistics).

    Args:
        img: NumPy array of any shape; it is flattened before counting.

    Returns:
        A tuple ``(hist, s_values)``. ``hist`` holds the per-bin counts.
        On the exact-bin path ``s_values`` is ``np.arange(n_bins)``; on the
        ``np.histogram`` path it is the array of ``n_bins + 1`` bin edges,
        so ``s_values[i]`` is the left edge of bin ``i``.
    """
    flat = img.reshape(-1)

    # Number of bins when every integer value gets its own bin.
    if flat.dtype == np.uint8:
        exact_bins = 2 ** 8
    elif flat.dtype == np.uint16:
        exact_bins = 2 ** 16
    else:
        exact_bins = None

    if exact_bins is not None:
        hist = np.bincount(flat, minlength=exact_bins)
        s_values = np.arange(exact_bins)
    else:
        hist, s_values = np.histogram(flat, bins=2 ** 16,
                                      range=(flat.min(), flat.max()))

    # Drop the extreme bins from the statistics.
    hist[0] = 0
    hist[-1] = 0
    return hist, s_values


def GetPercentStretchValue(img, left_clip=0.001, right_clip=0.001):
    """Find the pixel values at the lower and upper clip quantiles.

    The histogram from ``CalHistogram`` is accumulated into a normalised
    cumulative distribution; the value whose cumulative share is closest to
    ``left_clip`` becomes the lower limit and the one closest to
    ``1 - right_clip`` becomes the upper limit.

    Args:
        img: NumPy array holding one band.
        left_clip: Fraction of pixels to cut off at the dark end.
        right_clip: Fraction of pixels to cut off at the bright end.

    Returns:
        ``(img_min_clip, img_max_clip)`` taken from the value table that
        ``CalHistogram`` returns.
    """
    counts, levels = CalHistogram(img)

    # Cumulative share of pixels up to each bin; the small constant keeps
    # the division defined when every count is zero.
    cdf = np.cumsum(counts).astype(np.float64)
    cdf = cdf / (cdf[-1] + 1.0E-5)

    upper_target = 1.0 - right_clip
    lower_idx = np.abs(cdf - left_clip).argmin()
    upper_idx = np.abs(cdf - upper_target).argmin()
    return levels[lower_idx], levels[upper_idx]


def percent_stretch_image(input_image_data, left_clip=0.001, right_clip=0.001, left_mask=None,
                          right_mask=None):
    """Percent-clip stretch an image to 8 bits, band by band.

    For each band the clip limits come from ``GetPercentStretchValue``;
    pixels are clamped to those limits and rescaled linearly to 0..255.
    A band whose two limits coincide (for example a constant band) cannot
    be rescaled and is written as all zeros instead of dividing by zero.

    Args:
        input_image_data: Array laid out as (rows, cols) or
            (rows, cols, bands), or ``None``.
        left_clip: Fraction of pixels clipped at the dark end.
        right_clip: Fraction of pixels clipped at the bright end.
        left_mask: Accepted for interface compatibility; currently unused.
        right_mask: Accepted for interface compatibility; currently unused.

    Returns:
        A ``uint8`` array with the same shape as the input, or ``None``
        when the input is ``None``.
    """
    if input_image_data is None:
        return None

    is_2d = input_image_data.ndim == 2
    # Handle both layouts through one (rows, cols, bands) view.
    cube = input_image_data[:, :, np.newaxis] if is_2d else input_image_data
    out_8bit_data = np.zeros(cube.shape, dtype=np.uint8)

    for i_band in range(cube.shape[-1]):
        band = cube[:, :, i_band]
        img_clip_min, img_clip_max = GetPercentStretchValue(band, left_clip=left_clip,
                                                            right_clip=right_clip)
        if not img_clip_max > img_clip_min:
            # Degenerate range (or NaN limits): leave the band at zero.
            continue
        # Work in float64 so the subtraction cannot wrap for unsigned input.
        clipped = np.clip(band, img_clip_min, img_clip_max).astype(np.float64)
        scaled = (clipped - img_clip_min) / (img_clip_max - img_clip_min) * 255
        out_8bit_data[:, :, i_band] = scaled.astype(np.uint8)

    return out_8bit_data[:, :, 0] if is_2d else out_8bit_data


def read_tif(file_path):
    """Read every band of a raster file into one NumPy array.

    The array dtype is taken from what GDAL's ``ReadAsArray`` returns for
    the first band. Deriving it from the lowercased GDAL type name is not
    reliable: GDAL's 8-bit unsigned type is called "Byte", while NumPy's
    ``'byte'`` is a *signed* 8-bit integer, so values above 127 would wrap.
    Allocating the cube directly in that dtype also avoids an intermediate
    full-size ``float64`` copy.

    Args:
        file_path: Path of the raster to open with GDAL.

    Returns:
        ``(data_set, pro)`` where ``data_set`` has shape
        (rows, cols, bands) and ``pro`` is the projection string reported
        by the dataset. Returns ``None`` (after printing an error) when the
        file cannot be opened or contains no raster bands.
    """
    ds = gdal.Open(file_path)
    if ds is None:
        print("Error || Can't open {0} as tif file.".format(file_path))
        return None

    cols = ds.RasterXSize
    rows = ds.RasterYSize
    bands = ds.RasterCount
    pro = ds.GetProjection()
    if bands == 0:
        print("Error || {0} has no raster bands.".format(file_path))
        del ds
        return None

    data_set = None
    for i in range(bands):
        band_pixels = ds.GetRasterBand(i + 1).ReadAsArray()
        if data_set is None:
            # First band decides the dtype; later bands are cast into it.
            data_set = np.empty((rows, cols, bands), dtype=band_pixels.dtype)
        data_set[:, :, i] = band_pixels

    # Dropping the last reference closes the GDAL dataset.
    del ds
    return data_set, pro


def writeTiff(im_data, im_geotrans, im_proj, path):
    """Write an array to a GeoTIFF file.

    The output pixel type follows the array dtype (previously the file was
    always created as 8-bit regardless of the data). Signed 16-bit data is
    stored as ``GDT_Int16`` rather than as unsigned. Dtypes without an
    entry in the table fall back to ``GDT_Float32``.

    Args:
        im_data: Array laid out as (bands, height, width), or a single
            band as (height, width).
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_proj: Projection passed to ``SetProjection``.
        path: Destination file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If the GTiff driver fails to create the file.
    """
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,  # kept from the original mapping
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float32': gdal.GDT_Float32,
        'float64': gdal.GDT_Float64,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 2:
        # Promote a single band to a one-band stack.
        im_data = im_data[np.newaxis, :, :]
    elif im_data.ndim != 3:
        raise ValueError(
            "im_data must be 2-D or 3-D, got {0} dimensions".format(im_data.ndim))
    im_bands, im_height, im_width = im_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        raise RuntimeError("Can't create {0} as tif file.".format(path))

    dataset.SetGeoTransform(im_geotrans)  # write the affine transform parameters
    dataset.SetProjection(im_proj)  # write the projection
    print(f'波段总和{im_bands}')
    for i in tqdm.tqdm(range(im_bands)):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    # Dropping the last reference closes the dataset and flushes it to disk.
    del dataset


if __name__ == '__main__':
    # Script entry point: read one raster, percent-stretch it to 8 bits and
    # write the result next to the input.
    in_file = r"G:\02工作空间\beijing_img\data\GF2_PMS1_E110.7_N21.4_20210223_L1A0005501536_pansharpen.tif"
    path = r"G:\02工作空间\beijing_img\data\GF2_PMS1_E110.7_N21.4_20210223_L1A0005501536_Line.tif"

    loaded = read_tif(in_file)
    if loaded is None:
        # read_tif has already printed why the file could not be read.
        sys.exit(1)
    img, pro = loaded

    # read_tif hands back pixels and projection only, so the geotransform
    # is fetched from the source dataset here.
    src_ds = gdal.Open(in_file)
    geotrans = src_ds.GetGeoTransform()
    del src_ds

    # Diagnostics: raw statistics, then the same statistics after a plain
    # min-max rescale to 0..255. They are derived from the three scalars
    # instead of materialising a rescaled float copy of the whole image.
    raw_min, raw_mean, raw_max = img.min(), img.mean(), img.max()
    print(raw_min, raw_mean, raw_max)
    raw_span = raw_max - raw_min
    print('img raw s:',
          (raw_min - raw_min) / raw_span * 255,
          (raw_mean - raw_min) / raw_span * 255,
          (raw_max - raw_min) / raw_span * 255)

    img = percent_stretch_image(img)

    # The stretched image is (rows, cols, bands) while writeTiff expects
    # (bands, rows, cols); reorder the axes before writing.
    if img.ndim == 3:
        img = np.ascontiguousarray(img.transpose(2, 0, 1))
    writeTiff(img, geotrans, pro, path)

